import numpy as np
import re
import copy
import sys
import pandas as pd
import itertools
import networkx as nx
import networkx.algorithms.graph_hashing as graph_hashing
import networkx.algorithms.isomorphism as isomorph
from collections import defaultdict
from .molgraph import n2formula
from scipy.spatial import cKDTree
from .transformer import *
from .parameter import Parameter
from .structure import Structure
from .structurereader import conver_structure
from .structurewriter import  write_structure

'''Functions for extracting, connecting and capping molecular fragments.'''

# Public functions exported by a star-import of this module.
__all__ = ['frag_by_breakbond', 'frag_by_trimatom', 'connect_frag', 'cap_with']

def print_frag(frag_list):
    '''Report every fragment in frag_list to the screen.

    Each fragment is printed as a header line (formula of its 'sn' atoms,
    number of sites and the site labels) followed by one line per site with
    the cut bond, the elements of the two bond atoms and the formula of the
    removed offcut.  When frag_list holds more than one fragment, each header
    is prefixed with "index <i>: " and the site lines are indented to line up
    under it.  Nothing is printed for an empty frag_list.
    '''
    multiple = len(frag_list) > 1
    for idx, frag in enumerate(frag_list):
        graph = frag['graph']
        sites = frag['site']
        header = ('Fragment {:s} with {:d} sites ({:s}) obtained by breaking bond'
                  .format(n2formula(graph, frag['sn']), len(sites), ','.join(sites.keys())))
        if multiple:
            tag = 'index {:d}'.format(idx)
            print(tag + ': ' + header)
            # align the site lines with the text that follows "index <i>: "
            indent = ' ' * (len(tag) + 2)
        else:
            print(header)
            indent = ''
        for label, bond in sites.items():
            bond_sn = '-'.join(str(atom) for atom in bond)
            bond_elem = '-'.join(graph.nodes[atom]['elem'] for atom in bond)
            removed = n2formula(graph, frag['offcut'][label])
            print(indent + '{:s} cut:{:s}({:s}) rm:{:s}'.format(label, bond_sn, bond_elem, removed))

def frag_by_breakbond(mol, size=0, nfrag=0, **kwargs):
    '''Cut the given bonds of a molecule and return the resulting fragment(s).

    Parameters
    ----------
    mol : dict-like
        Molecule record; ``mol['graph']`` is the graph of the whole molecule.
        It is not modified.
    size : int
        If > 0, only fragments with more than ``size`` atoms are kept.
    nfrag : int
        If >= 1, at most the ``nfrag`` largest fragments are kept.
    **kwargs : str
        One keyword per cut site, ``label='inner-outer'``, for example
        ``frag_by_breakbond(mol, A1='15-23', A2='24-56')``.  ``inner`` (15, 24)
        is the atom sn that belongs to this fragment; ``outer`` (23, 56) is the
        atom that will be in another fragment or replaced by a cap group.
        Using at least two characters per label is advised: the first is the
        fragment label and the rest is the connection-site label.  The label
        is used by connect_frag to join fragments with matching sites.

    Returns
    -------
    list of defaultdict
        One entry per kept fragment, largest first, with keys

        - 'sn': set of atom sn of the fragment
        - 'formula': formula of those atoms
        - 'site': label -> (inner, outer) bond, only for sites whose inner
          atom is in this fragment (sn of the original molecule)
        - 'offcut': label -> set of atom sn cut away at that site, for the
          same labels as 'site'
        - 'graph': a copy of the whole-molecule graph
        - 'mol': the input molecule

    Exits the program with status 1 if no fragment contains an inner atom,
    or if one connected piece contains both an inner and an outer atom.
    '''
    site = {k: tuple(map(int, v.split('-'))) for k, v in kwargs.items()}
    site2outer = {k: v[1] for k, v in site.items()}
    edges = list(site.values())
    inner_atoms = [e[0] for e in edges]
    # break the bonds on a private copy so the caller's graph is untouched
    G = copy.deepcopy(mol['graph'])
    for e in edges:
        if G.has_edge(*e):
            G.remove_edge(*e)
        else:
            print('Warning! edge {:s} not present in the graph'.format('-'.join([str(i) for i in e])))
    main_frag = []
    offcut = {}
    for frag in nx.connected_components(G):
        is_main = any(i in frag for i in inner_atoms)
        if is_main:
            main_frag.append(frag)
        for k, v in site2outer.items():
            if v in frag:
                if is_main:
                    # the cut did not separate inner from outer (e.g. a ring bond)
                    print('Error! a fragment contain both inner and outer atoms')
                    sys.exit(1)
                offcut[k] = frag
    if not main_frag:
        print('Error! Could not extract fragment by breaking bonds {:s}\n'
              'Note that the atom sn order in bonds should be inner-outer'
              .format(','.join(kwargs.values())))
        sys.exit(1)
    main_frag.sort(key=len, reverse=True)
    if len(main_frag) > 1:
        print('Warning! more than one ({:s}) fragment generated by breaking bond.'
              .format(','.join(str(len(i)) for i in main_frag)))
    if size > 0:
        main_frag = [f for f in main_frag if len(f) > size]
    if nfrag >= 1:
        main_frag = main_frag[:nfrag]
    all_frag = []
    for f in main_frag:
        result_frag = defaultdict(str)
        result_frag['sn'] = f
        result_frag['formula'] = n2formula(mol['graph'], f)
        # site and offcut are based on the sn of the original molecule and are
        # restricted to the sites whose inner atom belongs to this fragment
        result_frag['site'] = {k: v for k, v in site.items() if v[0] in f}
        result_frag['offcut'] = {k: v for k, v in offcut.items() if k in result_frag['site']}
        result_frag['graph'] = copy.deepcopy(mol['graph'])
        result_frag['mol'] = mol
        all_frag.append(result_frag)
    print_frag(all_frag)
    return all_frag

# comparison operators allowed in keep/remove conditions; '=' and '==' compare
# the attribute as a string, the ordering operators compare as floats
_TRIM_OPS = {
    '=': lambda x, y: str(x) == y,
    '==': lambda x, y: str(x) == y,
    '>': lambda x, y: float(x) > float(y),
    '<': lambda x, y: float(x) < float(y),
    '>=': lambda x, y: float(x) >= float(y),
    '<=': lambda x, y: float(x) <= float(y),
}
# two-character operators first so '>=' is not split as '>' followed by '='
_TRIM_OP_RE = re.compile(r'(==|>=|<=|=|>|<)')


def _parse_trim_conditions(spec):
    '''Parse "k1=v1,k2>v2;k3<=v3" into [[(k1, '=', v1), (k2, '>', v2)], [(k3, '<=', v3)]].

    ';' separates alternative groups (an atom matching any group is selected)
    and ',' separates clauses that must all hold.  Blank groups/clauses are
    ignored.  Raises ValueError for a clause without an operator.
    '''
    groups = []
    for group in spec.split(';'):
        clauses = []
        for clause in group.split(','):
            if not clause.strip():
                continue
            parts = _TRIM_OP_RE.split(clause, maxsplit=1)
            if len(parts) != 3:
                raise ValueError('invalid trim condition {!r}: expected <attr><op><value>'.format(clause))
            key, op, value = parts
            clauses.append((key.strip(), op, value.strip()))
        if clauses:
            groups.append(clauses)
    return groups


def _select_trim_atoms(G, groups):
    '''Return the nodes of G, in node order, whose attributes satisfy any condition group.'''
    return [sn for sn in G.nodes()
            if any(all(_TRIM_OPS[op](G.nodes[sn][key], value) for key, op, value in group)
                   for group in groups)]


def frag_by_trimatom(mol, keep='', remove='', size=0, nfrag=1, nohydrogen=True, trim_record=None):
    '''Extract fragment(s) by keeping/removing atoms that match attribute conditions.

    The bonds between the selected atoms and the rest of the molecule are cut
    with frag_by_breakbond, using automatically generated site labels
    T01, T02, ...  Conditions are tested against the node attributes of
    ``mol['graph']``.

    Condition format for keep/remove::

        keep="elem=C,conju_state=fc;elem=H,conju_state=hc"

    ';' separates alternatives (an atom matching any group is selected) and
    ',' separates clauses that must all hold.  '=' / '==' compare the
    attribute as a string; '>', '<', '>=', '<=' compare as floats.

    If keep matches atoms, those atoms minus the ones matched by remove are
    kept; otherwise, if remove matches atoms, every other atom is kept.  So
    remove has higher priority than keep.  Note that the molecule does not
    need to be unwrapped for it to be trimmed properly.

    Parameters
    ----------
    mol : dict-like
        Molecule record with the molecular graph in ``mol['graph']``.
    keep, remove : str
        Condition strings (see above); empty means "not given".
    size, nfrag : int
        Passed to frag_by_breakbond: fragments with more than ``size`` atoms
        are kept, at most ``nfrag`` of them (largest first).
    nohydrogen : bool
        If True, bonds to hydrogens on the boundary of the kept region are not
        cut, so those hydrogens stay in the fragment.
    trim_record : dict or None
        Reserved for caching previous trim profiles; currently unused.

    Returns
    -------
    list
        The fragment dicts returned by frag_by_breakbond.
    '''
    G = mol['graph']
    keep_atom = _select_trim_atoms(G, _parse_trim_conditions(keep)) if keep else []
    remove_atom = set(_select_trim_atoms(G, _parse_trim_conditions(remove))) if remove else set()
    if keep_atom:
        keep_atom = [i for i in keep_atom if i not in remove_atom]
    elif remove_atom:
        keep_atom = [i for i in G.nodes() if i not in remove_atom]
    # find the boundary of keep_atom and cut every bond that crosses it
    site_dict = {}
    if keep_atom:
        keep_set = set(keep_atom)
        break_bond = []
        for o in nx.node_boundary(G, keep_atom):
            if G.nodes[o]['elem'] == 'H' and nohydrogen:
                continue
            break_bond += [(i, o) for i in G.neighbors(o) if i in keep_set]
        for site_idx, b in enumerate(break_bond, start=1):
            site_dict['T{:02d}'.format(site_idx)] = '-'.join(str(i) for i in b)
    return frag_by_breakbond(mol, nfrag=nfrag, size=size, **site_dict)


def connect_frag(fragA,fragB,joinsite=''):
    '''Connect two fragments through matching sites and return the joined fragment.

    fragA and fragB are fragment dicts such as those returned by
    frag_by_breakbond (keys 'sn', 'site', 'offcut', 'graph', 'mol').  The
    coordinates of fragB are transformed onto fragA in steps:

    1. move the outer atom of fragB's site onto the inner atom of fragA's site
    2. rotate the inner atom of fragB to align with the outer atom of fragA
    3. move fragB along the bond to correct the bond length (covalent radii)
    4. rotate the dihedral around the new bond.  With one join site, reference
       neighbours are chosen by a degree/atomic-number matching score.  With
       two join sites, the second site atoms are brought as close as possible.

    Site matching when joinsite is not given compares the labels from the
    second character on, up to the length of the shorter label:

    1. A2 matches B2 or B21 but not B3 or B31
    2. A21 matches B21 but not B22 or B223
    3. a one-character label such as A matches any label (B1, B81, Baxy)
    4. pairs whose first characters differ (A1-B1) are preferred over pairs
       whose first characters are equal (A1-A1)
    5. only one site pair is kept

    To specify the join site manually use joinsite='A1-B2', or
    joinsite='A1-B2,A3-B4' for two join sites; a pair whose labels are given
    in B-A order is swapped automatically.

    Returns
    -------
    new_frag : defaultdict
        Joined fragment with renumbered 'sn', the remaining 'site' and
        'offcut' entries of both fragments, the combined 'graph', and the
        'mol' of fragA (only fragA's parent info is kept).
    mappingA, mappingB : dict
        Old sn -> new sn for fragA and fragB.  The order is fragA atoms,
        fragB atoms, then the offcuts of fragA's and fragB's remaining sites.
    annihilated_site : list
        For every joined pair, [(sa, ia, oa, iia, ooa), (sb, ib, ob, iib, oob)]:
        the site label, the inner/outer atom sn, and the sn of the other
        neighbours of the inner and outer atoms.  The neighbours are sorted
        by (degree, atomnum), highest first.

    Exits the program with status 1 if no site pair can be found or validated.
    '''
    param = Parameter()
    valid_sp = []
    site_pair0 = []  # candidate pairs whose labels start with different characters
    site_pair1 = []  # candidate pairs whose labels start with the same character
    # first try to find the join site if it is not specified
    if not joinsite:
        for sa in fragA['site'].keys():
            for sb in fragB['site'].keys():
                compare_len = min(len(sa), len(sb))
                if sa[1:compare_len] == sb[1:compare_len] and sa[0] != sb[0]:
                    site_pair0.append((sa, sb))
                elif sa[1:compare_len] == sb[1:compare_len] and sa[0] == sb[0]:
                    site_pair1.append((sa, sb))
        if len(site_pair0 + site_pair1) == 0:
            print('Error! Could not find site to join')
            sys.exit(1)
        else:
            site_pair = (site_pair0 + site_pair1)[:1]
    else:
        site_pair = [tuple(i.split('-')) for i in joinsite.split(',')]

    # check that the join sites are in fragA and fragB (either label order)
    for s in site_pair:
        if s[0] in fragA['site'] and s[1] in fragB['site']:
            valid_sp.append(s)
        elif s[0] in fragB['site'] and s[1] in fragA['site']:
            valid_sp.append((s[1], s[0]))
    if not valid_sp:
        print('Error! no valid site found in fragA or fragB by {:s}'
              .format(joinsite))
        # nothing to align on: stop here instead of failing on site_pair[0] below
        sys.exit(1)
    print('site-pair to connect is {:s}'
          .format(','.join(['-'.join(i) for i in valid_sp])))
    site_pair = valid_sp

    # collect atom sn around each site; these atoms are used to align the fragments
    annihilated_site = []
    for s in site_pair:
        sa, sb = s
        gb = fragB['graph']
        ga = fragA['graph']
        ia, oa = [ga.nodes[i]['sn'] for i in fragA['site'][sa]]
        ib, ob = [gb.nodes[i]['sn'] for i in fragB['site'][sb]]
        iia, ooa = [[ga.nodes[j] for j in ga.neighbors(i) if j not in fragA['site'][sa]]
                    for i in fragA['site'][sa]]
        iib, oob = [[gb.nodes[j] for j in gb.neighbors(i) if j not in fragB['site'][sb]]
                    for i in fragB['site'][sb]]
        iia = sorted(iia, key=lambda x: (ga.degree(x['sn']), x['atomnum']), reverse=True)
        ooa = sorted(ooa, key=lambda x: (ga.degree(x['sn']), x['atomnum']), reverse=True)
        iib = sorted(iib, key=lambda x: (gb.degree(x['sn']), x['atomnum']), reverse=True)
        oob = sorted(oob, key=lambda x: (gb.degree(x['sn']), x['atomnum']), reverse=True)
        iia = [i['sn'] for i in iia]
        ooa = [i['sn'] for i in ooa]
        iib = [i['sn'] for i in iib]
        oob = [i['sn'] for i in oob]
        annihilated_site.append([(sa, ia, oa, iia, ooa), (sb, ib, ob, iib, oob)])

    # align fragB to fragA using the first site pair
    sa, sb = site_pair[0]
    gb = fragB['graph']
    ga = fragA['graph']
    ia, oa = [ga.nodes[i] for i in fragA['site'][sa]]
    ib, ob = [gb.nodes[i] for i in fragB['site'][sb]]
    iia, ooa = [[ga.nodes[j] for j in ga.neighbors(i) if j not in fragA['site'][sa]]
                for i in fragA['site'][sa]]
    iib, oob = [[gb.nodes[j] for j in gb.neighbors(i) if j not in fragB['site'][sb]]
                for i in fragB['site'][sb]]

    # first move the outer atom of fragB onto the inner atom of fragA
    trans_mat = translate_to_align(ob['coord'], ia['coord'])

    # second rotate the inner atom of fragB to align with the outer atom of fragA
    ib_coord = np.append(ib['coord'], 1) @ trans_mat
    rot_mat = rotate_to_align(ib_coord, oa['coord'], ia['coord'])

    # third calculate the bond length delta and move fragB along the bond
    delta_bl = param.covr[ia['elem']][0] - param.covr[ob['elem']][0]
    trans_mat2 = translate_along(ia['coord'], oa['coord'], delta_bl)
    if len(site_pair) == 1:
        # sort iia ooa iib oob by degree and atomnum; the larger degree and
        # atomnum are sorted to the head and have larger priority
        iia = sorted(iia, key=lambda x: (ga.degree(x['sn']), x['atomnum']), reverse=True)
        ooa = sorted(ooa, key=lambda x: (ga.degree(x['sn']), x['atomnum']), reverse=True)
        iib = sorted(iib, key=lambda x: (gb.degree(x['sn']), x['atomnum']), reverse=True)
        oob = sorted(oob, key=lambda x: (gb.degree(x['sn']), x['atomnum']), reverse=True)

        # compare iia with oob and ooa with iib; a lower score is a better match.
        # weights: 10000 no atoms, 1000 count mismatch, 100 degree mismatch,
        # 10 atomnum mismatch (same weights for both directions)
        score_iaob = 0
        score_iboa = 0
        if len(oob) == 0:
            score_iaob += 10000
        if len(iia) != len(oob):
            score_iaob += 1000
        for p in zip(iia, oob):
            if ga.degree(p[0]['sn']) != gb.degree(p[1]['sn']):
                score_iaob += 100
            if p[0]['atomnum'] != p[1]['atomnum']:
                score_iaob += 10
        if len(ooa) == 0:
            score_iboa += 10000
        if len(iib) != len(ooa):
            score_iboa += 1000
        for p in zip(iib, ooa):
            if gb.degree(p[0]['sn']) != ga.degree(p[1]['sn']):
                score_iboa += 100
            if p[0]['atomnum'] != p[1]['atomnum']:
                score_iboa += 10
        B = np.append(ib['coord'], 1) @ trans_mat @ rot_mat @ trans_mat2
        C = ia['coord']
        if score_iboa >= 10000 and score_iaob >= 10000:
            # no ooa and oob atoms: rotate head of iib to eclipse with head of iia
            A = np.append(iib[0]['coord'], 1) @ trans_mat @ rot_mat @ trans_mat2
            D = iia[0]['coord']
            torsion_mat = rotate_dihedral_to(A, B, C, D, angle=180)
        elif score_iaob <= score_iboa:
            # rotate head of oob to align with head of iia
            A = np.append(oob[0]['coord'], 1) @ trans_mat @ rot_mat @ trans_mat2
            D = iia[0]['coord']
            torsion_mat = rotate_dihedral_to(A, B, C, D, angle=0)
        else:
            # rotate head of iib to align with head of ooa
            A = np.append(iib[0]['coord'], 1) @ trans_mat @ rot_mat @ trans_mat2
            D = ooa[0]['coord']
            torsion_mat = rotate_dihedral_to(A, B, C, D, angle=0)
    else:
        # two sites to join: rotate the dihedral of the first site so that the
        # atoms of the second site come as close as possible
        s1a, s1b = site_pair[1]
        i1a, o1a = [ga.nodes[i] for i in fragA['site'][s1a]]
        i1b, o1b = [gb.nodes[i] for i in fragB['site'][s1b]]
        A = np.append(o1b['coord'], 1) @ trans_mat @ rot_mat @ trans_mat2
        B = np.append(ib['coord'], 1) @ trans_mat @ rot_mat @ trans_mat2
        C = ia['coord']
        D = i1a['coord']
        torsion_mat = rotate_dihedral_to(A, B, C, D, angle=0)

    # update coords in fragB (homogeneous coordinates, then drop the last column)
    coords = nx.get_node_attributes(gb, 'coord')
    coord_sn = list(coords.keys())
    coords = np.array(list(coords.values()))
    coords = np.hstack((coords, np.ones((coords.shape[0], 1))))
    updated_coords = coords @ trans_mat @ rot_mat @ trans_mat2 @ torsion_mat
    updated_coords = updated_coords[:, :3]
    nga = copy.deepcopy(ga)
    ngb = copy.deepcopy(gb)
    nx.set_node_attributes(ngb, {k: {'coord': list(v)} for k, v in zip(coord_sn, updated_coords)})
    # reset sn mapping for the connected fragments
    # the order is fragA - fragB - sites of A - sites of B
    # the offcuts of the joined sites are removed and are not in the mapping
    mappingA = {}
    mappingB = {}
    for i, sn in enumerate(sorted(fragA['sn'])):
        mappingA[sn] = i + 1
    base_num = len(fragA['sn'])
    for i, sn in enumerate(sorted(fragB['sn'])):
        mappingB[sn] = i + 1 + base_num
    base_num += len(fragB['sn'])
    for site, sn in fragA['offcut'].items():
        if site not in [p[0] for p in site_pair]:
            for i, n in enumerate(sorted(sn)):
                mappingA[n] = i + 1 + base_num
            base_num += len(sn)
    for site, sn in fragB['offcut'].items():
        if site not in [p[1] for p in site_pair]:
            for i, n in enumerate(sorted(sn)):
                mappingB[n] = i + 1 + base_num
            base_num += len(sn)
    # remove the offcut nodes of the joined sites
    for p in site_pair:
        nga.remove_nodes_from([i for i in fragA['offcut'][p[0]]])
        ngb.remove_nodes_from([i for i in fragB['offcut'][p[1]]])
    # relabel nodes by new sn
    nga = nx.relabel_nodes(nga, mappingA, copy=True)
    ngb = nx.relabel_nodes(ngb, mappingB, copy=True)
    nx.set_node_attributes(nga, {k: {'sn': k} for k in nga.nodes()})
    nx.set_node_attributes(ngb, {k: {'sn': k} for k in ngb.nodes()})
    # compose the new graph
    fragAB = nx.compose(nga, ngb)
    # add the edges that connect the two graphs (inner atom to inner atom)
    for sa, sb in site_pair:
        edge = (mappingA[fragA['site'][sa][0]], mappingB[fragB['site'][sb][0]])
        fragAB.add_edge(*edge)
    # generate the new connected fragment dictionary
    new_frag = defaultdict(str)
    new_frag['sn'] = [mappingA[i] for i in fragA['sn']] + [mappingB[i] for i in fragB['sn']]
    site_left_A = {k: v for k, v in fragA['site'].items() if k not in [p[0] for p in site_pair]}
    site_left_B = {k: v for k, v in fragB['site'].items() if k not in [p[1] for p in site_pair]}
    new_siteA = {k: tuple([mappingA[i] for i in v]) for k, v in site_left_A.items()}
    new_siteB = {k: tuple([mappingB[i] for i in v]) for k, v in site_left_B.items()}
    new_siteA.update(new_siteB)
    new_frag['site'] = new_siteA  # sites use the new sn
    new_siteA_sn = {k: tuple([mappingA[i] for i in fragA['offcut'][k]]) for k in site_left_A.keys()}
    new_siteB_sn = {k: tuple([mappingB[i] for i in fragB['offcut'][k]]) for k in site_left_B.keys()}
    new_siteA_sn.update(new_siteB_sn)
    new_frag['offcut'] = new_siteA_sn
    new_frag['graph'] = fragAB
    new_frag['mol'] = fragA['mol']
    return new_frag, mappingA, mappingB, annihilated_site


def cap_with(frag,elem='F',capfrag=''):
    '''Cap every open site of a fragment and return it as a structure object.

    For each site bond (inner, outer) in ``frag['site']``:

    - if the outer atom already is ``elem``, it is kept unchanged;
    - otherwise it is replaced by an ``elem`` atom that keeps the outer
      atom's sn.  The new atom is placed on the inner->outer bond, scaled
      from the edge 'weight' to the sum of the covalent radii of the inner
      atom and ``elem``, and is flagged with 'cap_atom'.

    The structure is built from the subgraph of the fragment atoms
    (``frag['sn']``) plus the capping atoms, and that subgraph is stored as
    ``st.graph``.  The coordinates need to be unwrapped.
    ``capfrag`` is currently unused.
    '''
    param = Parameter()
    G = frag['graph']
    cap_atoms = []
    cap_sn = []
    for bond in frag['site'].values():
        inner, outer = bond[0], bond[1]
        outer_node = G.nodes[outer]
        if outer_node['elem'] == elem:
            # already the requested element: reuse the existing atom as cap
            cap_atoms.append(outer_node)
        else:
            inner_node = G.nodes[inner]
            target_len = param.covr[inner_node['elem']][0] + param.covr[elem][0]
            scale = target_len / G[inner][outer]['weight']
            origin = np.array(inner_node['coord'])
            cap_coord = (np.array(outer_node['coord']) - origin) * scale + origin
            atom = defaultdict(str)
            atom['elem'] = elem
            atom['sn'] = outer_node['sn']
            atom['atomnum'] = param._elem2an[elem]
            atom['coord'] = list(cap_coord)
            atom['cap_atom'] = True
            if 'atomtype' in outer_node:
                atom['atomtype'] = elem
            cap_atoms.append(atom)
        cap_sn.append(outer)
    print('')
    subG = G.subgraph(list(frag['sn']) + cap_sn).copy()
    nx.set_node_attributes(subG, {a['sn']: a for a in cap_atoms})
    st = conver_structure(subG, 'graph')
    st.complete_self()
    st.graph = subG
    return st




